import os
from pathlib import Path
from tree_sitter import Language, Parser


# ----------------- Build tree-sitter libraries -----------------#
# Directory containing this file; the vendored grammars are located relative to it.
curdir = Path(__file__).parent.absolute()
# Root directory of the vendored grammar checkouts (each at "<BASEPATH>/tree-sitter-<lang>").
BASEPATH = os.path.abspath(os.path.join(curdir, "../../_deps/tree-sitter"))
# Grammar repository suffixes. These are directory names (e.g. "c-sharp"); they are
# not necessarily the language names later passed to ``Language(OUTPATH, ...)``.
LANGS = ["java", "python", "c", "c-sharp", "go", "javascript", "ruby", "php"]
PARSER_PATH = [f"{BASEPATH}/tree-sitter-{lang}" for lang in LANGS]
# Single shared library bundling every grammar listed above.
OUTPATH = f"{BASEPATH}/_build/my-languages.so"

# Compile all grammars into OUTPATH. This runs as an import-time side effect.
# NOTE(review): py-tree-sitter's build_library is expected to skip the rebuild when
# OUTPATH is newer than the grammar sources -- confirm for the installed version.
Language.build_library(OUTPATH, PARSER_PATH)

# Per-language tree-sitter node types that are treated as atomic leaves.
# create_custom_tree turns a node of one of these types into
# Node(<type>, children=[Node(<source text>, is_bpe=True)]) and does not descend
# into its tree-sitter children.
# Keys are the language names passed as ``lang`` (note "c_sharp" with an
# underscore, unlike the "c-sharp" grammar directory name in LANGS).
# "ERROR" is the node type tree-sitter uses for error nodes.
# NOTE(review): tree-sitter usually reports missing nodes through ``is_missing``
# with the expected node type, so the "MISSING" entries may never match -- confirm.
LEAF_NODE_TYPES = {
    "python": [
        "identifier",
        "integer",
        "string",
        "comment",
        "float",
        "ERROR",
        "MISSING",
    ],
    "java": [
        "identifier",
        "scoped_identifier",
        "type_identifier",
        "scoped_type_identifier",
        "line_comment",
        "block_comment",
        "hex_integer_literal",
        "decimal_integer_literal",
        "octal_integer_literal",
        "binary_integer_literal",
        "decimal_floating_point_literal",
        "hex_floating_point_literal",
        "character_literal",
        "string_literal",
        "text_block",
        "ERROR",
        "MISSING",
    ],
    "c": [
        "comment",
        "identifier",
        "type_identifier",
        "statement_identifier",
        "field_identifier",
        "string_literal",
        "number_literal",
        "char_literal",
        "system_lib_string",
        "preproc_arg",
        "preproc_directive",
        "ERROR",
        "MISSING",
    ],
    "c_sharp": [
        "identifier",
        "comment",
        "integer_literal",
        "boolean_literal",
        "null_literal",
        "character_literal",
        "real_literal",
        "string_literal",
        "verbatim_string_literal",
        "raw_string_literal",
        "preproc_integer_literal",
        "preproc_string_literal",
        "preproc_message",
        "interpolated_string_text",
        "interpolated_verbatim_string_text",
        "interpolated_raw_string_text",
        "interpolated_string_text_fragment",
        "ERROR",
        "MISSING",
    ],
    "php": [
        "identifier",
        "integer",
        "float",
        "boolean",
        "null",
        "comment",
        "name",
        "variable_name",
        "string",
        "escape_sequence",
        "namespace_name",
        "encapsed_string",
        "text",
        "heredoc",
        "nowdoc",
        "shell_command_expression",
        "ERROR",
        "MISSING",
    ],
    "go": [
        "identifier",
        "comment",
        "string_literal",
        "int_literal",
        "float_literal",
        "package_identifier",
        "interpreted_string_literal",
        "type_identifier",
        "field_identifier",
        "escape_sequence",
        "raw_string_literal",
        "rune_literal",
        "imaginary_literal",
        "name",
        "label_name",
        "ERROR",
        "MISSING",
    ],
    "javascript": [
        "identifier",
        "string",
        "comment",
        "number",
        "property_identifier",
        "statement_identifier",
        "shorthand_property_identifier",
        "regex_pattern",
        "shorthand_property_identifier_pattern",
        "jsx_text",
        "hash_bang_line",
        "hex_literal",
        "decimal_digits",
        "binary_literal",
        "octal_literal",
        "bigint_literal",
        "decimal_integer_literal",
        "ERROR",
        "MISSING",
        "template_string",
    ],
    "ruby": [
        "identifier",
        "constant",
        "string_content",
        "comment",
        "simple_symbol",
        "hash_key_symbol",
        "false",
        "true",
        "integer",
        "float",
        "complex",
        "rational",
        "heredoc_content",
        "string",
        "uninterpreted",
        "global_variable",
        "nil",
        "heredoc_body",
        "heredoc_beginning",
        "instance_variable",
        "class_variable",
        "string_array",
        "regex",
        "escape_sequence",
        "subshell",
        "ERROR",
        "MISSING",
    ],
}

# Python-grammar statement node types.
# NOTE(review): not referenced by the functions in this module; presumably consumed
# elsewhere (perhaps to decide where a newline is implied) -- confirm with callers.
PYTHON_NEWLINE_STATEMENTS = [
    "future_import_statement",
    "import_statement",
    "import_from_statement",
    "print_statement",
    "assert_statement",
    "expression_statement",
    "return_statement",
    "delete_statement",
    "raise_statement",
    "pass_statement",
    "break_statement",
    "continue_statement",
    "global_statement",
    "nonlocal_statement",
    "exec_statement",
]

# Per-language childless node types that create_custom_tree keeps as their type name
# rather than their (whitespace) source text. postprocess_python_code and
# postprocess_ruby_code turn these placeholders back into line breaks/indentation.
LANG_SPECIAL_TOKENS = {
    "python": ["indent", "dedent", "newline"],
    "ruby": ["line_break"],
}


# Tree-start marker; its only use in _serialize_preorder_dfs is commented out.
TOK_TREESTART = "[BOT]"
# Emitted after every node by the "preorder_dfs" serialization; closes a node's child list.
TOK_TREEEND = "[EOT]"
# Separator between tokens in a serialized tree string.
TOK_SPACE = " "
NODE_SPACE = "\u2581"  # \u2581 is ▁ and is used to replace spaces in node texts
# Prefix/suffix wrapped around internal-node labels by the
# "preorder_dfs_nodeleaf_toks" serialization (open/close tokens for a subtree).
NODE_START_TOK = "(_."
NODE_END_TOK = "._)"

# ----------------- Node datastructure -----------------#


class Node(object):
    """A node of the custom, tree-sitter-independent syntax tree.

    Attributes:
        text: the node label. For internal nodes this is a grammar node type;
            for leaves it is source text or a special token name.
        children: a list of child ``Node`` objects, or ``None`` for leaves
            created without children.
        is_bpe: ``True`` for leaves holding raw source text (see
            create_custom_tree). Such leaves are collected and counted
            separately from grammar labels.
    """

    def __init__(self, text, children=None, is_bpe=False):
        self.text = text
        self.children = children
        self.is_bpe = is_bpe

    def __repr__(self):
        # Report only the child count so that reprs of large trees stay readable.
        n_children = 0 if self.children is None else len(self.children)
        return f"{type(self).__name__}(text={self.text!r}, n_children={n_children}, is_bpe={self.is_bpe!r})"


# ----------------- Tree utility functions  -----------------#


def _add_java_class_(code):
    """Wrap a Java snippet in a dummy ``Test`` class so that bare members parse."""
    return "".join(("public class Test {\n", code, "\n}"))


def _add_php_class_(code):
    """Wrap a PHP snippet in ``<?php`` plus a dummy ``Test`` class so that bare methods parse."""
    return "".join(("<?php\nclass Test{\n", code, "\n}"))


# def _get_java_method_node_(root_node):
#     if root_node is None:
#         return None

#     if root_node.type == "method_declaration":
#         return root_node

#     for childnode in root_node.children:
#         res = _get_java_method_node_(childnode)
#         if res is not None:
#             return res

#     return None


def _get_java_method_node_(root_node):
    """Return the member wrapped by ``_add_java_class_`` (first ``class_body`` found, pre-order).

    The snippet is parsed inside ``public class Test { ... }``. The node of
    interest is therefore the first named member of that class body. Anonymous
    tokens such as ``{`` and ``}`` are skipped. Leading comments are also
    skipped so that a Javadoc or line comment is not returned in place of the
    member itself. A comment is returned only when the body holds nothing else.

    Args:
        root_node: a tree-sitter node (or ``None``).

    Returns:
        The member node, or ``None`` if no non-empty ``class_body`` is found.
    """
    if root_node is None:
        return None

    if root_node.type == "class_body":
        members = [child for child in root_node.children if child.is_named]
        for member in members:
            if member.type not in ("line_comment", "block_comment"):
                return member
        # Only comments (or nothing at all) inside the body.
        return members[0] if members else None

    for childnode in root_node.children:
        res = _get_java_method_node_(childnode)
        if res is not None:
            return res

    return None


# TODO: Update the function to find methods more generally
def _get_php_method_node_(root_node):
    """Return the first ``method_declaration`` node in pre-order, or ``None`` if there is none."""
    pending = [root_node]
    while pending:
        current = pending.pop()
        if current is None:
            continue
        if current.type == "method_declaration":
            return current
        # Push children right-to-left so that the leftmost child is examined next,
        # matching a recursive left-to-right pre-order walk.
        pending.extend(reversed(current.children))
    return None


def create_TS_tree(code, language, add_cls):
    """Parse ``code`` with the tree-sitter grammar ``language`` and return a CST node.

    Args:
        code: source text to parse.
        language: grammar name as compiled into ``OUTPATH``.
        add_cls: for Java and PHP, wrap the snippet in a dummy ``Test`` class
            before parsing, then return the wrapped node instead of the root.

    Returns:
        The tree's root node, or the node extracted from the wrapper class
        (possibly ``None``) when ``add_cls`` applies.
    """
    grammar = Language(OUTPATH, language)
    ts_parser = Parser()
    ts_parser.set_language(grammar)

    wrapped = add_cls and language in ("java", "php")
    if wrapped:
        code = _add_java_class_(code) if language == "java" else _add_php_class_(code)

    parsed = ts_parser.parse(bytes(code, "utf8"))

    if not wrapped:
        return parsed.root_node
    if language == "java":
        return _get_java_method_node_(parsed.root_node)
    return _get_php_method_node_(parsed.root_node)


def create_custom_tree(TS_node, lang, raise_on_error=False):
    """Convert a tree-sitter node and its subtree into a tree of ``Node`` objects.

    Conversion rules:
    * node types listed in ``LEAF_NODE_TYPES[lang]`` become
      ``Node(type, [Node(source_text, is_bpe=True)])`` and are not descended into;
    * other childless nodes become ``Node(type)`` if the type is one of
      ``LANG_SPECIAL_TOKENS[lang]``, else ``Node(source_text)``;
    * everything else becomes ``Node(type, converted_children)``.

    Args:
        TS_node: tree-sitter node, or ``None``.
        lang: language key into ``LEAF_NODE_TYPES`` / ``LANG_SPECIAL_TOKENS``.
        raise_on_error: if true, raise when ``TS_node.has_error`` is set. This
            is checked on the node passed in; recursive calls use the default.

    Returns:
        The converted ``Node``, or ``None`` if ``TS_node`` is ``None``.

    Raises:
        AssertionError: if ``raise_on_error`` is set and the node has errors.
        KeyError: if ``lang`` is not a key of ``LEAF_NODE_TYPES``.
    """

    def _as_text(value):
        # Decode bytes to str when possible; keep the value unchanged otherwise.
        try:
            return value.decode("utf-8")
        except (UnicodeDecodeError, AttributeError):
            return value

    if TS_node is None:
        return None

    if TS_node.has_error and raise_on_error:
        raise AssertionError("Code cannot be parsed.")

    leaf_types = LEAF_NODE_TYPES[lang]
    node_type = _as_text(TS_node.type)
    node_text = _as_text(TS_node.text)

    if node_type in leaf_types:
        return Node(node_type, children=[Node(node_text, is_bpe=True)], is_bpe=False)

    if TS_node.child_count == 0:
        # Whitespace-like tokens (INDENT/DEDENT/...) keep their type name as the label.
        special = LANG_SPECIAL_TOKENS.get(lang, ())
        return Node(node_type if node_type in special else node_text)

    return Node(node_type, [create_custom_tree(child, lang) for child in TS_node.children])


def print_TS_tree(node):
    """Print a tree-sitter subtree: one node type per line, tab-indented by depth (pre-order)."""
    pending = [(node, 0)]
    while pending:
        current, depth = pending.pop()
        print("\t" * depth + str(current.type))
        kids = current.children
        if kids:
            # Reverse so that the first child is popped (printed) first.
            pending.extend((kid, depth + 1) for kid in reversed(kids))


def print_tree(node):
    """Print a ``Node`` tree as ``<text> :: <is_bpe>`` lines, tab-indented by depth (pre-order)."""
    pending = [(node, 0)]
    while pending:
        current, depth = pending.pop()
        print("\t" * depth + str(current.text) + " :: " + str(current.is_bpe))
        kids = current.children
        if kids:
            # Reverse so that the first child is popped (printed) first.
            pending.extend((kid, depth + 1) for kid in reversed(kids))


def create_tokens_dict(root_node, tokendict, reverse_tokendict, rowidx, traversal_type, add_bpetoks):
    """Collect the token vocabulary of a ``Node`` tree, updating the given dicts in place.

    Args:
        root_node: root ``Node`` (or ``None``).
        tokendict: dict with set values under ``"bpe"`` and ``"nonbpe"``.
        reverse_tokendict: maps node text to ``{"bpe": [...], "nonbpe": [...]}``.
            ``rowidx`` is appended once per visited node.
        rowidx: identifier recorded in ``reverse_tokendict`` for every node.
        traversal_type: with ``"preorder_dfs_nodeleaf_toks"``, non-BPE nodes
            that have children contribute ``NODE_START_TOK+text`` and
            ``text+NODE_END_TOK`` instead of the bare text.
        add_bpetoks: whether BPE leaf texts are added to ``tokendict["bpe"]``.

    Returns:
        ``(tokendict, reverse_tokendict)``, the same objects that were passed in.
    """
    wrap_internal = traversal_type == "preorder_dfs_nodeleaf_toks"

    def _visit(node):
        if node is None:
            return

        text = node.text
        entry = reverse_tokendict.setdefault(text, {"bpe": [], "nonbpe": []})

        if node.is_bpe:
            if add_bpetoks:
                tokendict["bpe"].add(text)
            entry["bpe"].append(rowidx)
        else:
            if wrap_internal and node.children:
                tokendict["nonbpe"].add(f"{NODE_START_TOK}{text}")
                tokendict["nonbpe"].add(f"{text}{NODE_END_TOK}")
            else:
                tokendict["nonbpe"].add(text)
            entry["nonbpe"].append(rowidx)

        for child in node.children or ():
            _visit(child)

    _visit(root_node)
    return tokendict, reverse_tokendict


def count_nodes(node):
    """Return ``(bpe_count, nonbpe_count)`` over ``node`` and all its descendants."""
    if node is None:
        return 0, 0

    bpe_total = 0
    nonbpe_total = 0
    for child in node.children or ():
        child_bpe, child_nonbpe = count_nodes(child)
        bpe_total += child_bpe
        nonbpe_total += child_nonbpe

    if node.is_bpe:
        bpe_total += 1
    else:
        nonbpe_total += 1
    return bpe_total, nonbpe_total


# ----------------- Tree serialization functions  -----------------#


def serialize(node, type):
    """Serialize a ``Node`` tree into a ``TOK_SPACE``-separated token string.

    Args:
        node: root ``Node`` (or ``None``, which yields ``""``).
        type: ``"preorder_dfs"`` (labels followed by ``TOK_TREEEND`` after each
            node's children) or ``"preorder_dfs_nodeleaf_toks"`` (internal nodes
            wrapped in ``NODE_START_TOK`` / ``NODE_END_TOK``). The parameter name
            shadows the builtin but is kept for caller compatibility.

    Returns:
        The serialized tree string.

    Raises:
        ValueError: if ``type`` is not a supported serialization.
    """
    if type == "preorder_dfs":
        treelist = _serialize_preorder_dfs(node)
    elif type == "preorder_dfs_nodeleaf_toks":
        treelist = _serialize_preorder_dfs_nodeleaf_toks_(node)
    else:
        # Previously this fell through to an UnboundLocalError on ``treelist``.
        raise ValueError(f"Unsupported serialization type: {type!r}")

    return TOK_SPACE.join(treelist)


def _serialize_preorder_dfs(node):
    """Return the ``"preorder_dfs"`` token list for ``node``.

    Each node emits its label (spaces replaced by ``NODE_SPACE``), then the
    tokens of its children, then ``TOK_TREEEND``.

    Args:
        node: a ``Node`` (or ``None``, which yields no tokens).

    Returns:
        A list of token strings.
    """
    if node is None:
        return []

    text = node.text
    # Decode before replacing: calling bytes.replace with str arguments raises
    # TypeError. Undecodable bytes become U+FFFD, because the output must be str.
    if isinstance(text, bytes):
        text = text.decode("utf-8", errors="replace")

    treestr = [text.replace(" ", NODE_SPACE)]
    for childnode in node.children or ():
        treestr.extend(_serialize_preorder_dfs(childnode))

    treestr.append(TOK_TREEEND)
    return treestr


def _serialize_preorder_dfs_nodeleaf_toks_(node):
    """Return the ``"preorder_dfs_nodeleaf_toks"`` token list for ``node``.

    Leaves emit their label alone. Internal nodes emit
    ``NODE_START_TOK + label``, their children's tokens, then
    ``label + NODE_END_TOK``. Spaces in labels are replaced by ``NODE_SPACE``.

    Args:
        node: a ``Node`` (or ``None``, which yields no tokens).

    Returns:
        A list of token strings.
    """
    if node is None:
        return []

    text = node.text
    # Decode bytes before replacing: bytes.replace with str arguments would
    # raise TypeError. Undecodable bytes become U+FFFD instead of crashing.
    if isinstance(text, bytes):
        text = text.decode("utf-8", errors="replace")
    text = text.replace(" ", NODE_SPACE)

    children = node.children or []
    if not children:
        return [text]

    treestr = [f"{NODE_START_TOK}{text}"]
    for childnode in children:
        treestr.extend(_serialize_preorder_dfs_nodeleaf_toks_(childnode))
    treestr.append(f"{text}{NODE_END_TOK}")
    return treestr


# ----------------- Tree de-serialization functions  -----------------#


def deserialize(treestr, type, special_tokens=None):
    """Rebuild a ``Node`` tree from a string produced by ``serialize``.

    Args:
        treestr: serialized tree. Leading and trailing whitespace is stripped.
        type: ``"preorder_dfs"`` or ``"preorder_dfs_nodeleaf_toks"``; it must
            match the format used by ``serialize``. The name shadows the
            builtin but is kept for caller compatibility.
        special_tokens: optional collection of tokens to drop before rebuilding.

    Returns:
        The root ``Node``, or ``None`` if no node could be read.

    Raises:
        ValueError: if ``type`` is not a supported serialization.
    """
    treestr = treestr.strip()

    if type == "preorder_dfs":
        root_node = _deserialize_preorder_dfs(treestr, special_tokens)
    elif type == "preorder_dfs_nodeleaf_toks":
        root_node = _deserialize_preorder_dfs_nodeleaf_toks_(treestr, special_tokens)
    else:
        # Previously this fell through to an UnboundLocalError on ``root_node``.
        raise ValueError(f"Unsupported serialization type: {type!r}")

    return root_node


def _deserialize_preorder_dfs(treestr, special_tokens):
    """Rebuild a ``Node`` tree from ``"preorder_dfs"`` tokens.

    Each node's label is followed by its children and then ``TOK_TREEEND``.
    ``NODE_SPACE`` characters are turned back into spaces. Nodes are created
    with ``children=[]``, and the list is filled when children are read.

    Args:
        treestr: ``TOK_SPACE``-separated serialized tree.
        special_tokens: optional collection of tokens to drop first.

    Returns:
        The root ``Node``, or ``None`` for an exhausted or immediately closed stream.
    """
    tokens = treestr.split(TOK_SPACE)
    if special_tokens is not None:
        tokens = [tok for tok in tokens if tok not in special_tokens]

    # Consume tokens front-to-back through an iterator. The former list.pop(0)
    # made deserialization quadratic in the number of tokens.
    token_iter = iter(tokens)

    def create_tree():
        nodestr = next(token_iter, None)
        if nodestr is None:
            return None

        nodestr = nodestr.replace(NODE_SPACE, " ")
        if nodestr == TOK_TREEEND:
            return None

        node = Node(nodestr, children=[])
        childnodes = []
        child = create_tree()
        while child is not None:
            childnodes.append(child)
            child = create_tree()

        if childnodes:
            node.children = childnodes

        return node

    return create_tree()


def _deserialize_preorder_dfs_nodeleaf_toks_(treestr, special_tokens):
    """Rebuild a ``Node`` tree from ``"preorder_dfs_nodeleaf_toks"`` tokens.

    A token starting with ``NODE_START_TOK`` opens an internal node, and the
    following tokens are its children. A token ending with ``NODE_END_TOK``
    closes the current node. Any other token is a leaf. ``NODE_SPACE``
    characters are turned back into spaces.

    Args:
        treestr: ``TOK_SPACE``-separated serialized tree.
        special_tokens: optional collection of tokens to drop first.

    Returns:
        The root ``Node``, or ``None`` for an exhausted or immediately closed stream.
    """
    tokens = treestr.split(TOK_SPACE)
    if special_tokens is not None:
        tokens = [tok for tok in tokens if tok not in special_tokens]

    # Iterator instead of list.pop(0), which was O(n) per token.
    token_iter = iter(tokens)

    def create_tree():
        nodestr = next(token_iter, None)
        if nodestr is None:
            return None

        nodestr = nodestr.replace(NODE_SPACE, " ")

        if nodestr.endswith(NODE_END_TOK):
            return None

        if not nodestr.startswith(NODE_START_TOK):
            return Node(nodestr)

        # Strip only the leading marker. The old str.replace also removed
        # any later occurrences inside the label.
        node = Node(nodestr[len(NODE_START_TOK):], children=[])
        childnodes = []
        child = create_tree()
        while child is not None:
            childnodes.append(child)
            child = create_tree()

        if childnodes:
            node.children = childnodes

        return node

    return create_tree()


# ----------------- Tree to code functions  -----------------#


def convert_tree_to_code_generic(root, lang):
    """Flatten a ``Node`` tree back into a space-separated code string.

    Leaves contribute their text. Internal nodes join their children's output
    with single spaces. Internal nodes whose label contains ``"comment"`` get a
    trailing line-break marker, so that following code does not land on the
    comment's line. The marker is the ``newline`` special token for Python,
    ``line_break`` for Ruby (both expanded by the postprocess_* helpers), and
    a literal newline for every other language.

    Args:
        root: root ``Node`` (or ``None``, which yields ``""``).
        lang: language name selecting the line-break marker.

    Returns:
        The code string, before any language-specific postprocessing.
    """
    if root is None:
        return ""

    if not root.children:
        return root.text

    if "comment" in root.text:
        if lang == "python":
            suffix = " newline"
        elif lang == "ruby":
            # Ruby's special line-break token. The check previously compared
            # against "comment", which is never a language name.
            suffix = " line_break"
        else:
            suffix = " \n"
    else:
        suffix = ""

    body = " ".join(convert_tree_to_code_generic(child, lang) for child in root.children)
    return body + suffix


def postprocess_python_code(code):
    """Replace Python ``indent`` / ``dedent`` / ``newline`` placeholders with real layout.

    Each placeholder becomes ``"\\n"`` followed by two spaces per current
    indentation level: ``indent`` increments the level first, ``dedent``
    decrements it first, and ``newline`` leaves it unchanged. Tokens are split
    on and re-joined with single spaces.
    """
    unit = "  "
    depth = 0
    pieces = []

    for tok in code.split(" "):
        if tok == "indent":
            depth += 1
        elif tok == "dedent":
            depth -= 1
        elif tok != "newline":
            pieces.append(tok)
            continue
        pieces.append("\n" + unit * depth)

    return " ".join(pieces)


def postprocess_ruby_code(code):
    """Replace each Ruby ``line_break`` placeholder token with a newline character."""
    return " ".join("\n" if tok == "line_break" else tok for tok in code.split(" "))


def convert_tree_to_code(root, lang):
    """Turn a ``Node`` tree back into code, applying Python/Ruby layout postprocessing."""
    flat = convert_tree_to_code_generic(root, lang)

    if lang == "python":
        return postprocess_python_code(flat)
    if lang == "ruby":
        return postprocess_ruby_code(flat)
    return flat
